import json
import logging
import os
import sys
import time

from _dynamic_dataset.EvalDataset import EvalDataset
from _dynamic_dataset.EvalSubset import EvalSubset
from _dynamic_dataset.EvalType import EvalType
from utils.datasetUtils import splitTrainDataByNums
from utils.modelUtils import getAvailableDevice, getResNet18
from utils.randomUtils import all_combinations

from _dynamic_dataset.AccCombinationRecord import AccCombinationRecord, WriteRecordsToFile
from _dynamic_dataset.TrainDataset import TrainSubset
from rootPath import project_path
from server_client.Client import Client
from server_client.Server import Server
import torch
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
from datetime import datetime
import torch.nn.functional as F
'''
Experiment goal: compute the ground-truth Shapley value when every party's
dataset has the same size and the same class distribution.
'''

# Timestamp of this run (available for log/output naming).
now = datetime.now()

# Training hyperparameters.
batch_size = 128
learning_rate = 0.0001
num_epochs = 50
# threshold = args.threshold
# maxNum = args.maxNum


def _load_conf(name):
    """Read a JSON configuration file from the project's conf directory."""
    with open(project_path + "/conf/" + name + ".json", 'r') as conf_file:
        return json.load(conf_file)


# Client/server configuration for the gradient-based Shapley computation.
clientConf = _load_conf("client")
serverConf = _load_conf("server")
# 数据预处理
# CIFAR-10 per-channel mean/std used by both transform pipelines.
_CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
_CIFAR10_STD = (0.2023, 0.1994, 0.2010)

# Training preprocessing: random crop + horizontal flip augmentation,
# then tensor conversion and normalization.
train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(_CIFAR10_MEAN, _CIFAR10_STD),
])

# Evaluation preprocessing: no augmentation, only normalization.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(_CIFAR10_MEAN, _CIFAR10_STD),
])

# Per-client data allocation: 5 clients, each receiving 1000 samples of
# every one of the 10 CIFAR-10 classes (identical size and distribution).
numsList = [{label: 1000 for label in range(10)} for _ in range(5)]
clientNum = len(numsList)
clients = []
print("生成id组合完成")

device = getAvailableDevice()
# Server-side evaluation data: the CIFAR-10 test split, wrapped so the set of
# evaluated indices can be resampled dynamically between aggregations.
# test_set = torchvision.datasets.CIFAR10(root=project_path+'/data', train=False, download=True, transform=test_transform)
# test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=2)

val_set = torchvision.datasets.CIFAR10(root=project_path + '/data', train=False, download=True,
                                       transform=test_transform)
evalDataset = EvalDataset(val_set.data, val_set.targets, 30, 1.00, test_transform)
evalSubDataset = EvalSubset(evalDataset)
eval_loader = DataLoader(evalSubDataset, batch_size=batch_size, shuffle=False, num_workers=2)
# Training set, partitioned across clients according to numsList.
train_dataset = torchvision.datasets.CIFAR10(root=project_path + '/data/', train=True, download=True,
                                             transform=train_transform)
model = getResNet18()
datasetIndices = splitTrainDataByNums(orderList=numsList, trainDataset=train_dataset)
for i in range(clientNum):
    # BUG FIX: index with i (was datasetIndices[0]) so each client trains on
    # its own partition instead of all clients sharing partition 0.
    trainSubDataset = TrainSubset(train_dataset, datasetIndices[i])
    clients.append(
        Client(clientConf, model, device,
               # use the shared batch_size hyperparameter (was hard-coded 128)
               DataLoader(trainSubDataset, shuffle=True, batch_size=batch_size, num_workers=2), i))
    print(f"客户端{i}创建成功")
print("客户端全部初始化成功！")

# Enumerate every coalition of client ids for Shapley evaluation.
clientIds = list(range(clientNum))
combinations = all_combinations(clientIds)

# Server owning the global model and the evaluation loader.
server = Server(serverConf, eval_loader, device, model)
# Total number of examples in the full evaluation set.
test_example_num = len(val_set.targets)
# Recorders for per-coalition fitted accuracy and ignore rate.
acc_record = AccCombinationRecord()
ignore_rate_record = AccCombinationRecord()
# Wall-clock start of the experiment.
start_time = time.time()
# Evaluation sampling ratio selected by coalition size (index = #clients).
sample_ratio_list = [0, 1.0, 1.0, 0.5, 0.1, 0.1]

for epoch in range(1, num_epochs + 1):
    logging.info(f"epoch={epoch}")
    # Each client trains locally starting from the current global model.
    for client in clients:
        client.train(server.global_model)
    # For every client coalition: aggregate into sub_model, then measure its
    # accuracy on the dynamically sampled evaluation subset.
    for combination in combinations:
        # Collect the client objects belonging to this coalition.
        com_clients = []
        for clientId in combination:
            com_clients.append(clients[clientId])
        server.model_aggregate(com_clients, server.sub_model)
        # Evaluate the aggregated sub-model on the test subset.
        # Per-batch buffers for correctness flags and prediction confidences.
        batch_results = []
        batch_confidences = []
        correct = 0
        total = 0
        server.sub_model.eval()
        # Resample which test cases to evaluate; the sampling ratio depends on
        # the coalition size (sample_ratio_list indexed by #clients).
        evalSubDataset.updateIndices(EvalType.WEIGHT_CHOOSE_IGNORE_SET, sample_ratio_list[len(com_clients)])
        indices = evalSubDataset.indices
        print(f"待测试的测试用例数{len(indices)}")
        with torch.no_grad():
            for data in server.eval_loader:
                images, labels = data
                images, labels = images.to(device), labels.to(device)
                outputs = server.sub_model(images)
                # Highest-probability class and its softmax confidence.

                confidences, predicted = torch.max(F.softmax(outputs, dim=1), 1)
                # Record per-example correctness (1/0) for this batch.
                batch_results.append((predicted == labels).cpu().int())
                batch_confidences.append(confidences)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
            # Concatenate all per-batch results into single tensors.
            acc_results = torch.cat(batch_results, dim=0)
            confidences_result = torch.cat(batch_confidences, dim=0)
            if len(com_clients) >= 3:
                # Only coalitions of >= 3 clients feed the history that drives
                # future test-case selection.
                evalSubDataset.addConfidenceRecord(confidences_result)
                evalSubDataset.addEvalRecord(acc_results)
        # Fitted accuracy: ignored examples are assumed 100% correct
        # (precision1); the sampled examples use the measured accuracy.
        precision1 = 100.0
        precision2 = 100 * correct / total
        ir = (test_example_num - len(indices)) / test_example_num * 100
        fit_precision = (precision1 * (test_example_num - len(indices)) / test_example_num) + (precision2 * (len(indices)) / test_example_num)
        acc_record.addCombinationAcc(epoch, combination, fit_precision)
        ignore_rate_record.addCombinationAcc(epoch, combination, ir)
        print(f"epoch={epoch},combination={combination},fit_precision={fit_precision}")
    # Re-aggregate ALL clients into the global model for the next epoch.
    server.model_aggregate(clients, server.global_model)
# Wall-clock end of the experiment.
end_time = time.time()

